

<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
  <meta charset="utf-8">
  
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  
  <title>PyTorch Models &mdash; Causal Discovery Toolbox 0.5.22 documentation</title>
  

  
  <link rel="stylesheet" href="_static/css/theme.css" type="text/css" />
  <link rel="stylesheet" href="_static/pygments.css" type="text/css" />
  <link rel="stylesheet" href="_static/custom.css" type="text/css" />

  
  
    <link rel="shortcut icon" href="_static/favicon.png"/>
  
  
  

  
  <!--[if lt IE 9]>
    <script src="_static/js/html5shiv.min.js"></script>
  <![endif]-->
  
    
      <script type="text/javascript" id="documentation_options" data-url_root="./" src="_static/documentation_options.js"></script>
        <script src="_static/jquery.js"></script>
        <script src="_static/underscore.js"></script>
        <script src="_static/doctools.js"></script>
        <script src="_static/language_data.js"></script>
        <script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
        <script type="text/x-mathjax-config">MathJax.Hub.Config({"extensions": ["tex2jax.js"], "jax": ["input/TeX", "output/HTML-CSS"], "tex2jax": {"inlineMath": [["$", "$"], ["\\(", "\\)"]], "displayMath": [["$$", "$$"], ["\\[", "\\]"]], "processEscapes": true}, "HTML-CSS": {"fonts": ["TeX"]}})</script>
    
    <script type="text/javascript" src="_static/js/theme.js"></script>

    
    <link rel="index" title="Index" href="genindex.html" />
    <link rel="search" title="Search" href="search.html" />
    <link rel="next" title="Developer Documentation" href="developer.html" />
    <link rel="prev" title="Toolbox Settings" href="settings.html" /> 
</head>

<body class="wy-body-for-nav">

   
  <div class="wy-grid-for-nav">
    
    <nav data-toggle="wy-nav-shift" class="wy-nav-side">
      <div class="wy-side-scroll">
        <div class="wy-side-nav-search" >
          

          
            <a href="index.html">
          

          
            
            <img src="_static/banner.png" class="logo" alt="Logo"/>
          
          </a>

          
            
            
              <div class="version">
                0.5.22
              </div>
            
          

          
<div role="search">
  <form id="rtd-search-form" class="wy-form" action="search.html" method="get">
    <input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
    <input type="hidden" name="check_keywords" value="yes" />
    <input type="hidden" name="area" value="default" />
  </form>
</div>

          
        </div>

        
        <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
          
            
            
              
            
            
              <ul>
<li class="toctree-l1"><a class="reference internal" href="index.html">Causal Discovery Toolbox Documentation</a></li>
</ul>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="tutorial.html">Get started</a></li>
<li class="toctree-l1"><a class="reference internal" href="causality.html">cdt.causality</a></li>
<li class="toctree-l1"><a class="reference internal" href="independence.html">cdt.independence</a></li>
<li class="toctree-l1"><a class="reference internal" href="data.html">cdt.data</a></li>
<li class="toctree-l1"><a class="reference internal" href="utils.html">cdt.utils</a></li>
<li class="toctree-l1"><a class="reference internal" href="metrics.html">cdt.metrics</a></li>
<li class="toctree-l1"><a class="reference internal" href="settings.html">Toolbox Settings</a></li>
<li class="toctree-l1 current"><a class="current reference internal" href="#">PyTorch Models</a><ul>
<li class="toctree-l2"><a class="reference internal" href="#cgnn">CGNN</a></li>
<li class="toctree-l2"><a class="reference internal" href="#sam">SAM</a></li>
<li class="toctree-l2"><a class="reference internal" href="#ncc">NCC</a></li>
<li class="toctree-l2"><a class="reference internal" href="#gnn">GNN</a></li>
<li class="toctree-l2"><a class="reference internal" href="#fsgnn">FSGNN</a></li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="developer.html">Developer Documentation</a></li>
</ul>

            
          
        </div>
        
      </div>
    </nav>

    <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">

      
      <nav class="wy-nav-top" aria-label="top navigation">
        
          <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
          <a href="index.html">Causal Discovery Toolbox</a>
        
      </nav>


      <div class="wy-nav-content">
        
        <div class="rst-content">
        
          















<div role="navigation" aria-label="breadcrumbs navigation">

  <ul class="wy-breadcrumbs">
    
      <li><a href="index.html" class="icon icon-home" aria-label="Home"></a> &raquo;</li>
        
      <li>PyTorch Models</li>
    
    
      <li class="wy-breadcrumbs-aside">
        
            
            <a href="_sources/models.rst.txt" rel="nofollow"> View page source</a>
          
        
      </li>
    
  </ul>

  
  <hr/>
</div>
          <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
           <div itemprop="articleBody">
            
  <div class="section" id="pytorch-models">
<h1>PyTorch Models<a class="headerlink" href="#pytorch-models" title="Permalink to this headline">¶</a></h1>
<p>For more flexibility in the use of the neural network models,
these are directly accessible as <cite>torch.nn.Module</cite> objects through the <cite>.model</cite> submodules, for example:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">cdt.causality.graph.model</span> <span class="kn">import</span> <span class="n">CGNN_model</span>
</pre></div>
</div>
<p>to import the CGNN PyTorch model. The available models are the following:</p>
<ul class="simple">
<li><p>CGNN</p></li>
<li><p>SAM</p></li>
<li><p>NCC</p></li>
<li><p>GNN</p></li>
<li><p>FSGNN</p></li>
</ul>
<div class="section" id="cgnn">
<h2>CGNN<a class="headerlink" href="#cgnn" title="Permalink to this headline">¶</a></h2>
<dl class="py class">
<dt id="cdt.causality.graph.model.CGNN_model">
<em class="property">class </em><code class="sig-prename descclassname">cdt.causality.graph.model.</code><code class="sig-name descname">CGNN_model</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">adj_matrix</span></em>, <em class="sig-param"><span class="n">batch_size</span></em>, <em class="sig-param"><span class="n">nh</span><span class="o">=</span><span class="default_value">20</span></em>, <em class="sig-param"><span class="n">device</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">confounding</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">initial_graph</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="o">**</span><span class="n">kwargs</span></em><span class="sig-paren">)</span><a class="reference internal" href="_modules/cdt/causality/graph/CGNN.html#CGNN_model"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#cdt.causality.graph.model.CGNN_model" title="Permalink to this definition">¶</a></dt>
<dd><p>Class defining the CGNN model.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>adj_matrix</strong> (<em>numpy.array</em>) – Adjacency Matrix of the model to evaluate</p></li>
<li><p><strong>batch_size</strong> (<em>int</em>) – Minibatch size. ~500 is recommended</p></li>
<li><p><strong>nh</strong> (<em>int</em>) – number of hidden units in the hidden layers</p></li>
<li><p><strong>device</strong> (<em>str</em>) – device on which the computation is to be run</p></li>
<li><p><strong>confounding</strong> (<em>bool</em>) – Enables the confounding variant</p></li>
<li><p><strong>initial_graph</strong> (<em>numpy.array</em>) – Initial graph in the confounding case.</p></li>
</ul>
</dd>
</dl>
<dl class="py method">
<dt id="cdt.causality.graph.model.CGNN_model.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="_modules/cdt/causality/graph/CGNN.html#CGNN_model.forward"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#cdt.causality.graph.model.CGNN_model.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Generate data according to the topological order of the graph;
outputs a batch of generated data of size batch_size.</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
<dd class="field-odd"><p>Generated data</p>
</dd>
<dt class="field-even">Return type</dt>
<dd class="field-even"><p>torch.Tensor</p>
</dd>
</dl>
</dd></dl>

<dl class="py method">
<dt id="cdt.causality.graph.model.CGNN_model.run">
<code class="sig-name descname">run</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">dataset</span></em>, <em class="sig-param"><span class="n">train_epochs</span><span class="o">=</span><span class="default_value">1000</span></em>, <em class="sig-param"><span class="n">test_epochs</span><span class="o">=</span><span class="default_value">1000</span></em>, <em class="sig-param"><span class="n">verbose</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">idx</span><span class="o">=</span><span class="default_value">0</span></em>, <em class="sig-param"><span class="n">lr</span><span class="o">=</span><span class="default_value">0.01</span></em>, <em class="sig-param"><span class="n">dataloader_workers</span><span class="o">=</span><span class="default_value">0</span></em>, <em class="sig-param"><span class="o">**</span><span class="n">kwargs</span></em><span class="sig-paren">)</span><a class="reference internal" href="_modules/cdt/causality/graph/CGNN.html#CGNN_model.run"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#cdt.causality.graph.model.CGNN_model.run" title="Permalink to this definition">¶</a></dt>
<dd><p>Run the CGNN on a given graph.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>dataset</strong> (<em>torch.utils.data.Dataset</em>) – True Data, on the same device as the model.</p></li>
<li><p><strong>train_epochs</strong> (<em>int</em>) – number of train epochs</p></li>
<li><p><strong>test_epochs</strong> (<em>int</em>) – number of test epochs</p></li>
<li><p><strong>verbose</strong> (<em>bool</em>) – verbosity of the model</p></li>
<li><p><strong>idx</strong> (<em>int</em>) – indicator for printing purposes</p></li>
<li><p><strong>lr</strong> (<em>float</em>) – learning rate of the model</p></li>
<li><p><strong>dataloader_workers</strong> (<em>int</em>) – number of workers for dataset loading</p></li>
</ul>
</dd>
<dt class="field-even">Returns</dt>
<dd class="field-even"><p>Average score of the graph on <cite>test_epochs</cite> epochs</p>
</dd>
<dt class="field-odd">Return type</dt>
<dd class="field-odd"><p>float</p>
</dd>
</dl>
</dd></dl>

</dd></dl>

</div>
<div class="section" id="sam">
<h2>SAM<a class="headerlink" href="#sam" title="Permalink to this headline">¶</a></h2>
<dl class="py class">
<dt id="cdt.causality.graph.model.SAM_generators">
<em class="property">class </em><code class="sig-prename descclassname">cdt.causality.graph.model.</code><code class="sig-name descname">SAM_generators</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">data_shape</span></em>, <em class="sig-param"><span class="n">nh</span></em>, <em class="sig-param"><span class="n">skeleton</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">linear</span><span class="o">=</span><span class="default_value">False</span></em><span class="sig-paren">)</span><a class="reference internal" href="_modules/cdt/causality/graph/SAM.html#SAM_generators"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#cdt.causality.graph.model.SAM_generators" title="Permalink to this definition">¶</a></dt>
<dd><p>Ensemble of all the SAM generators.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>data_shape</strong> (<em>tuple</em>) – Shape of the true data</p></li>
<li><p><strong>nh</strong> (<em>int</em>) – Initial number of hidden units in the hidden layers</p></li>
<li><p><strong>skeleton</strong> (<em>numpy.ndarray</em>) – Initial skeleton, defaults to a fully connected graph</p></li>
<li><p><strong>linear</strong> (<em>bool</em>) – Enables the linear variant</p></li>
</ul>
</dd>
</dl>
<dl class="py method">
<dt id="cdt.causality.graph.model.SAM_generators.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">data</span></em>, <em class="sig-param"><span class="n">noise</span></em>, <em class="sig-param"><span class="n">adj_matrix</span></em>, <em class="sig-param"><span class="n">drawn_neurons</span><span class="o">=</span><span class="default_value">None</span></em><span class="sig-paren">)</span><a class="reference internal" href="_modules/cdt/causality/graph/SAM.html#SAM_generators.forward"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#cdt.causality.graph.model.SAM_generators.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Forward through all the generators.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>data</strong> (<em>torch.Tensor</em>) – True data</p></li>
<li><p><strong>noise</strong> (<em>torch.Tensor</em>) – Samples of noise variables</p></li>
<li><p><strong>adj_matrix</strong> (<em>torch.Tensor</em>) – Sampled adjacency matrix</p></li>
<li><p><strong>drawn_neurons</strong> (<em>torch.Tensor</em>) – Sampled matrix of active neurons</p></li>
</ul>
</dd>
<dt class="field-even">Returns</dt>
<dd class="field-even"><p>Batch of generated data</p>
</dd>
<dt class="field-odd">Return type</dt>
<dd class="field-odd"><p>torch.Tensor</p>
</dd>
</dl>
</dd></dl>

</dd></dl>

<dl class="py class">
<dt id="cdt.causality.graph.model.SAM_discriminator">
<em class="property">class </em><code class="sig-prename descclassname">cdt.causality.graph.model.</code><code class="sig-name descname">SAM_discriminator</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">nfeatures</span></em>, <em class="sig-param"><span class="n">dnh</span></em>, <em class="sig-param"><span class="n">hlayers</span><span class="o">=</span><span class="default_value">2</span></em>, <em class="sig-param"><span class="n">mask</span><span class="o">=</span><span class="default_value">None</span></em><span class="sig-paren">)</span><a class="reference internal" href="_modules/cdt/causality/graph/SAM.html#SAM_discriminator"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#cdt.causality.graph.model.SAM_discriminator" title="Permalink to this definition">¶</a></dt>
<dd><p>SAM discriminator.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>nfeatures</strong> (<em>int</em>) – Number of variables in the dataset</p></li>
<li><p><strong>dnh</strong> (<em>int</em>) – Number of hidden units in the hidden layers</p></li>
<li><p><strong>hlayers</strong> (<em>int</em>) – Number of hidden layers</p></li>
<li><p><strong>mask</strong> (<em>numpy.ndarray</em>) – Mask of connections to ignore</p></li>
</ul>
</dd>
</dl>
<dl class="py method">
<dt id="cdt.causality.graph.model.SAM_discriminator.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">input</span></em>, <em class="sig-param"><span class="n">obs_data</span><span class="o">=</span><span class="default_value">None</span></em><span class="sig-paren">)</span><a class="reference internal" href="_modules/cdt/causality/graph/SAM.html#SAM_discriminator.forward"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#cdt.causality.graph.model.SAM_discriminator.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Forward pass in the discriminator.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>input</strong> (<em>torch.Tensor</em>) – True Data or generated data</p></li>
<li><p><strong>obs_data</strong> (<em>torch.Tensor</em>) – True data in the case of <cite>input=generated</cite> for padding.</p></li>
</ul>
</dd>
<dt class="field-even">Returns</dt>
<dd class="field-even"><p>Output of the discriminator</p>
</dd>
<dt class="field-odd">Return type</dt>
<dd class="field-odd"><p>torch.Tensor</p>
</dd>
</dl>
</dd></dl>

</dd></dl>

</div>
<div class="section" id="ncc">
<h2>NCC<a class="headerlink" href="#ncc" title="Permalink to this headline">¶</a></h2>
<dl class="py class">
<dt id="cdt.causality.pairwise.model.NCC_model">
<em class="property">class </em><code class="sig-prename descclassname">cdt.causality.pairwise.model.</code><code class="sig-name descname">NCC_model</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">n_hiddens</span><span class="o">=</span><span class="default_value">20</span></em>, <em class="sig-param"><span class="n">kernel_size</span><span class="o">=</span><span class="default_value">3</span></em><span class="sig-paren">)</span><a class="reference internal" href="_modules/cdt/causality/pairwise/NCC.html#NCC_model"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#cdt.causality.pairwise.model.NCC_model" title="Permalink to this definition">¶</a></dt>
<dd><p>NCC model structure.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>n_hiddens</strong> (<em>int</em>) – Number of hidden features</p></li>
<li><p><strong>kernel_size</strong> (<em>int</em>) – Kernel size of the convolutions</p></li>
</ul>
</dd>
</dl>
<dl class="py method">
<dt id="cdt.causality.pairwise.model.NCC_model.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="reference internal" href="_modules/cdt/causality/pairwise/NCC.html#NCC_model.forward"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#cdt.causality.pairwise.model.NCC_model.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Passing data through the network.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><p><strong>x</strong> (<em>torch.Tensor</em>) – 2D tensor containing both (x, y) variables</p>
</dd>
<dt class="field-even">Returns</dt>
<dd class="field-even"><p>output of NCC</p>
</dd>
<dt class="field-odd">Return type</dt>
<dd class="field-odd"><p>torch.Tensor</p>
</dd>
</dl>
</dd></dl>

</dd></dl>

</div>
<div class="section" id="gnn">
<h2>GNN<a class="headerlink" href="#gnn" title="Permalink to this headline">¶</a></h2>
<dl class="py class">
<dt id="cdt.causality.pairwise.model.GNN_model">
<em class="property">class </em><code class="sig-prename descclassname">cdt.causality.pairwise.model.</code><code class="sig-name descname">GNN_model</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">batch_size</span></em>, <em class="sig-param"><span class="n">nh</span><span class="o">=</span><span class="default_value">20</span></em>, <em class="sig-param"><span class="n">lr</span><span class="o">=</span><span class="default_value">0.01</span></em>, <em class="sig-param"><span class="n">train_epochs</span><span class="o">=</span><span class="default_value">1000</span></em>, <em class="sig-param"><span class="n">test_epochs</span><span class="o">=</span><span class="default_value">1000</span></em>, <em class="sig-param"><span class="n">idx</span><span class="o">=</span><span class="default_value">0</span></em>, <em class="sig-param"><span class="n">verbose</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">dataloader_workers</span><span class="o">=</span><span class="default_value">0</span></em>, <em class="sig-param"><span class="o">**</span><span class="n">kwargs</span></em><span class="sig-paren">)</span><a class="reference internal" href="_modules/cdt/causality/pairwise/GNN.html#GNN_model"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#cdt.causality.pairwise.model.GNN_model" title="Permalink to this definition">¶</a></dt>
<dd><p>Torch model for the GNN structure.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>batch_size</strong> (<em>int</em>) – size of the batch going to be fed to the model</p></li>
<li><p><strong>nh</strong> (<em>int</em>) – Number of hidden units in the hidden layer</p></li>
<li><p><strong>lr</strong> (<em>float</em>) – Learning rate of the model</p></li>
<li><p><strong>train_epochs</strong> (<em>int</em>) – Number of train epochs</p></li>
<li><p><strong>test_epochs</strong> (<em>int</em>) – Number of test epochs</p></li>
<li><p><strong>idx</strong> (<em>int</em>) – Index (for printing purposes)</p></li>
<li><p><strong>verbose</strong> (<em>bool</em>) – Verbosity of the model</p></li>
<li><p><strong>dataloader_workers</strong> (<em>int</em>) – Number of workers for dataset loading</p></li>
<li><p><strong>device</strong> (<em>str</em>) – Device on which the algorithm is going to be run</p></li>
</ul>
</dd>
</dl>
<dl class="py method">
<dt id="cdt.causality.pairwise.model.GNN_model.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="reference internal" href="_modules/cdt/causality/pairwise/GNN.html#GNN_model.forward"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#cdt.causality.pairwise.model.GNN_model.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Pass data through the net structure.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><p><strong>x</strong> (<em>torch.Tensor</em>) – input data of shape (:, 1)</p>
</dd>
<dt class="field-even">Returns</dt>
<dd class="field-even"><p>Output of the shallow net</p>
</dd>
<dt class="field-odd">Return type</dt>
<dd class="field-odd"><p>torch.Tensor</p>
</dd>
</dl>
</dd></dl>

<dl class="py method">
<dt id="cdt.causality.pairwise.model.GNN_model.run">
<code class="sig-name descname">run</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">dataset</span></em><span class="sig-paren">)</span><a class="reference internal" href="_modules/cdt/causality/pairwise/GNN.html#GNN_model.run"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#cdt.causality.pairwise.model.GNN_model.run" title="Permalink to this definition">¶</a></dt>
<dd><p>Run the GNN on a pair x,y of FloatTensor data.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><p><strong>dataset</strong> (<em>torch.utils.data.Dataset</em>) – True data; First element is the cause</p>
</dd>
<dt class="field-even">Returns</dt>
<dd class="field-even"><p>Score of the configuration</p>
</dd>
<dt class="field-odd">Return type</dt>
<dd class="field-odd"><p>torch.Tensor</p>
</dd>
</dl>
</dd></dl>

</dd></dl>

</div>
<div class="section" id="fsgnn">
<h2>FSGNN<a class="headerlink" href="#fsgnn" title="Permalink to this headline">¶</a></h2>
<dl class="py class">
<dt id="cdt.independence.graph.model.FSGNN_model">
<em class="property">class </em><code class="sig-prename descclassname">cdt.independence.graph.model.</code><code class="sig-name descname">FSGNN_model</code><span class="sig-paren">(</span><em class="sig-param">sizes</em>, <em class="sig-param">dropout=0.0</em>, <em class="sig-param">activation_function=&lt;class 'torch.nn.modules.activation.ReLU'&gt;</em><span class="sig-paren">)</span><a class="reference internal" href="_modules/cdt/independence/graph/FSGNN.html#FSGNN_model"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#cdt.independence.graph.model.FSGNN_model" title="Permalink to this definition">¶</a></dt>
<dd><p>Variant of CGNN for feature selection.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>sizes</strong> (<em>list</em>) – Size of the neural network layers</p></li>
<li><p><strong>dropout</strong> (<em>float</em>) – Dropout rate of the neural connections</p></li>
<li><p><strong>activation_function</strong> (<em>torch.nn.Module</em>) – Activation function of the network</p></li>
</ul>
</dd>
</dl>
<dl class="py method">
<dt id="cdt.independence.graph.model.FSGNN_model.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="reference internal" href="_modules/cdt/independence/graph/FSGNN.html#FSGNN_model.forward"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#cdt.independence.graph.model.FSGNN_model.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Forward pass in the network.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><p><strong>x</strong> (<em>torch.Tensor</em>) – input data</p>
</dd>
<dt class="field-even">Returns</dt>
<dd class="field-even"><p>output of the network</p>
</dd>
<dt class="field-odd">Return type</dt>
<dd class="field-odd"><p>torch.Tensor</p>
</dd>
</dl>
</dd></dl>

<dl class="py method">
<dt id="cdt.independence.graph.model.FSGNN_model.train">
<code class="sig-name descname">train</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">dataset</span></em>, <em class="sig-param"><span class="n">lr</span><span class="o">=</span><span class="default_value">0.01</span></em>, <em class="sig-param"><span class="n">l1</span><span class="o">=</span><span class="default_value">0.1</span></em>, <em class="sig-param"><span class="n">batch_size</span><span class="o">=</span><span class="default_value">-1</span></em>, <em class="sig-param"><span class="n">train_epochs</span><span class="o">=</span><span class="default_value">1000</span></em>, <em class="sig-param"><span class="n">test_epochs</span><span class="o">=</span><span class="default_value">1000</span></em>, <em class="sig-param"><span class="n">device</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">verbose</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">dataloader_workers</span><span class="o">=</span><span class="default_value">0</span></em><span class="sig-paren">)</span><a class="reference internal" href="_modules/cdt/independence/graph/FSGNN.html#FSGNN_model.train"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#cdt.independence.graph.model.FSGNN_model.train" title="Permalink to this definition">¶</a></dt>
<dd><p>Train the network and output the scores of the features.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>dataset</strong> (<em>torch.utils.data.Dataset</em>) – Original data</p></li>
<li><p><strong>lr</strong> (<em>float</em>) – Learning rate</p></li>
<li><p><strong>l1</strong> (<em>float</em>) – Coefficient of the L1 regularization</p></li>
<li><p><strong>batch_size</strong> (<em>int</em>) – Batch size of the model, defaults to the dataset size.</p></li>
<li><p><strong>train_epochs</strong> (<em>int</em>) – Number of train epochs</p></li>
<li><p><strong>test_epochs</strong> (<em>int</em>) – Number of test epochs</p></li>
<li><p><strong>device</strong> (<em>str</em>) – Device on which the computation is to be run</p></li>
<li><p><strong>verbose</strong> (<em>bool</em>) – Verbosity of the model</p></li>
<li><p><strong>dataloader_workers</strong> (<em>int</em>) – Number of workers for dataset loading</p></li>
</ul>
</dd>
<dt class="field-even">Returns</dt>
<dd class="field-even"><p>feature selection scores for each feature.</p>
</dd>
<dt class="field-odd">Return type</dt>
<dd class="field-odd"><p>list</p>
</dd>
</dl>
</dd></dl>

</dd></dl>

</div>
</div>


           </div>
           
          </div>
          <footer>
  
    <div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
      
        <a href="developer.html" class="btn btn-neutral float-right" title="Developer Documentation" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right"></span></a>
      
      
        <a href="settings.html" class="btn btn-neutral float-left" title="Toolbox Settings" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left"></span> Previous</a>
      
    </div>
  

  <hr/>

  <div role="contentinfo">
    <p>
        
        &copy; Copyright 2018, Diviyan Kalainathan, Olivier Goudet

    </p>
  </div>
    
    
    
    Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a
    
    <a href="https://github.com/rtfd/sphinx_rtd_theme">theme</a>
    
    provided by <a href="https://readthedocs.org">Read the Docs</a>. 

</footer>

        </div>
      </div>

    </section>

  </div>
  

  <script type="text/javascript">
      jQuery(function () {
          SphinxRtdTheme.Navigation.enable(true);
      });
  </script>

  
  
    
   

</body>
</html>